# Copyright 2017 Bo Shao. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import re
import tensorflow as tf
import time

from settings import PROJECT_ROOT
from chatbot.botpredictor import BotPredictor

# Set the TF_CPP_MIN_LOG_LEVEL environment variable to '3' for this process;
# presumably this quiets TensorFlow's native (C++) log output during the run.
# NOTE(review): this assignment executes after `import tensorflow` above --
# confirm the variable is still honoured when set at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def test_demo():
    """Feed each sample sentence to the predictor and log the replies.

    Reads ``Data/Test/samples.txt`` (under PROJECT_ROOT) line by line,
    skipping blank lines and lines that start with ``#==``. For every
    remaining sentence, the prompt and the predictor's reply are appended to
    ``Data/Test/responses.txt``; one header from ``get_header()`` is written
    at the start of each run. Occurrences of ``_nl_`` and ``_np_`` in a reply
    are replaced with newlines before it is written. The elapsed prediction
    time is printed when the run finishes.
    """
    print("# Creating TF session ...")

    def data_path(*parts):
        # Every input and output location sits under <PROJECT_ROOT>/Data.
        return os.path.join(PROJECT_ROOT, 'Data', *parts)

    corpus_path = data_path('Corpus')
    knowledge_path = data_path('KnowledgeBase')
    result_path = data_path('Result')
    samples_path = data_path('Test', 'samples.txt')
    responses_path = data_path('Test', 'responses.txt')

    # Compiled once up front since the same pattern is applied to every reply.
    line_break_tokens = re.compile(r'_nl_|_np_')

    with tf.Session() as sess:
        predictor = BotPredictor(
            sess,
            corpus_dir=corpus_path,
            knbase_dir=knowledge_path,
            result_dir=result_path,
            result_file='basic',
        )
        session_id = predictor.session_data.add_session()

        print("# Prediction started ...")
        started_at = time.time()
        # The responses file is opened in append mode, so earlier runs are kept.
        with open(samples_path, 'r') as samples, open(responses_path, 'a') as responses:
            responses.write(get_header())
            for raw_line in samples:
                sentence = raw_line.strip()
                if sentence and not sentence.startswith("#=="):
                    # Echo the prompt before predicting, so it is recorded
                    # even if the prediction itself fails.
                    responses.write("> {}\n".format(sentence))
                    reply = predictor.predict(session_id, sentence)
                    cleaned = line_break_tokens.sub('\n', reply).strip()
                    responses.write("{}\n\n".format(cleaned))

        elapsed = time.time() - started_at
        print("# Prediction completed. Time spent on prediction: {:4.2f} seconds".format(elapsed))


def get_header():
    """Build the comment banner written at the start of each test run.

    Returns:
        A string made of three '#'-prefixed lines -- a description, the
        command used, and the current time formatted as ``YYYY-MM-DD HH:MM``
        -- followed by one blank line.
    """
    generated_at = time.strftime("%Y-%m-%d %H:%M")
    banner_lines = [
        "# This file was generated by testdemo.py for testing purpose. It reads samples.txt "
        "lines by line, and feeds each line to the predictor.",
        "# Command: python testdemo.py",
        "# Date and Time Generated: {}".format(generated_at),
    ]
    # Two trailing newlines: one ends the last banner line, one adds the blank line.
    return "\n".join(banner_lines) + "\n\n"


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    test_demo()
